import torch
from torch.utils.data import ConcatDataset, DataLoader, Dataset, Subset, random_split
import torchvision.transforms as transforms
from  .lda import Preprocess 
import numpy as np 
from torchvision.datasets import MNIST, CIFAR10, CIFAR100
import os
import gc
from transformers import MobileViTImageProcessor
 
class CustomDataset(Dataset):
    """Map-style dataset that pairs each feature with the label at the same index.

    ``feature_data`` and ``target_labels`` are stored exactly as given; any
    objects supporting ``len()`` and integer indexing will do.
    """

    def __init__(self, feature_data, target_labels):
        self.feature_data, self.target_labels = feature_data, target_labels

    def __len__(self):
        # The size is taken from the features alone; labels are not checked.
        return len(self.feature_data)

    def __getitem__(self, idx):
        """Return the ``(feature, target)`` pair stored at position ``idx``."""
        return self.feature_data[idx], self.target_labels[idx]


class Mnist_data():
    """Build federated MNIST loaders: one train/val pair per client plus a test loader.

    NUM_CLIENTS is the number of client partitions, BATCH_SIZE is forwarded to
    every DataLoader, and IID selects between equal random partitions (truthy)
    and partitions produced by ``Preprocess.create_lda_partitions`` (falsy).
    """

    def __init__(self, NUM_CLIENTS, IID, BATCH_SIZE):
        self.NUM_CLIENTS = NUM_CLIENTS
        self.BATCH_SIZE = BATCH_SIZE
        self.IID = IID

    def _split_client(self, client_ds):
        """Split one client's dataset 90/10 and return (train_loader, val_loader)."""
        len_val = len(client_ds) // 10
        len_train = len(client_ds) - len_val
        ds_train, ds_val = random_split(
            client_ds, [len_train, len_val], torch.Generator().manual_seed(42)
        )
        return (
            DataLoader(ds_train, batch_size=self.BATCH_SIZE, shuffle=True),
            DataLoader(ds_val, batch_size=self.BATCH_SIZE),
        )

    def load_datasets(self):
        """Download MNIST if needed and build the per-client loaders.

        Returns:
            ``(trainloaders, valloaders, testloader)`` where the first two are
            lists with one DataLoader per client and ``testloader`` iterates
            the full MNIST test split.

        Raises:
            ZeroDivisionError: if ``NUM_CLIENTS`` is 0.
        """
        preprocess = Preprocess()
        transform = transforms.Compose(
            [transforms.ToTensor()]
        )
        trainset = MNIST("/directory/for_data/", train=True, download=True, transform=transform)
        testset = MNIST("/directory/for_data/", train=False, download=True, transform=transform)

        if self.IID:
            # random_split requires the lengths to sum to len(trainset). When
            # NUM_CLIENTS does not divide it evenly, the leftover samples go
            # into an extra partition that is discarded.
            partition_size, remainder = divmod(len(trainset), self.NUM_CLIENTS)
            lengths = [partition_size] * self.NUM_CLIENTS
            if remainder:
                lengths.append(remainder)
            client_sets = random_split(trainset, lengths, torch.Generator().manual_seed(42))
            if remainder:
                client_sets = client_sets[:-1]
        else:
            # Only the LDA path needs every transformed image in one tensor,
            # so it is built here rather than unconditionally.
            data_new = torch.zeros(len(trainset), 28, 28)
            for index, (image, _) in enumerate(trainset):
                data_new[index] = image
            flwr_trainset = (data_new, np.array(trainset.targets, dtype=np.int32))
            partitions, _ = preprocess.create_lda_partitions(
                dataset=flwr_trainset,
                dirichlet_dist=None,
                num_partitions=self.NUM_CLIENTS,
                concentration=0.5,
                accept_imbalanced=True,
                seed=2,
            )
            # Each partition is indexed as (features, labels); features are
            # cast to float32 before being wrapped for the DataLoader.
            client_sets = [
                CustomDataset(part[0].astype(np.float32), part[1])
                for part in partitions
            ]

        trainloaders = []
        valloaders = []
        for client_ds in client_sets:
            train_loader, val_loader = self._split_client(client_ds)
            trainloaders.append(train_loader)
            valloaders.append(val_loader)
        testloader = DataLoader(testset, batch_size=self.BATCH_SIZE)
        return trainloaders, valloaders, testloader


class RGBToBGR:
    """Callable transform that reverses the order of an image's last (channel) axis."""

    def __call__(self, img):
        """Return ``img`` as a PIL image with its third axis reversed.

        ``img`` is converted to a NumPy array first, so it must yield an
        array with at least three dimensions.
        """
        # Reversing the third axis turns RGB channel order into BGR.
        reversed_channels = np.array(img)[:, :, ::-1]
        to_pil = transforms.functional.to_pil_image
        return to_pil(reversed_channels)


class Cifar_data():
    """Build federated CIFAR-10 loaders: one train/val pair per client plus a test loader.

    NUM_CLIENTS is the number of client partitions, BATCH_SIZE is forwarded to
    every DataLoader, and IID selects between equal random partitions (truthy)
    and partitions produced by ``Preprocess.create_lda_partitions`` (falsy).
    """

    def __init__(self, NUM_CLIENTS, IID, BATCH_SIZE):
        self.NUM_CLIENTS = NUM_CLIENTS
        self.BATCH_SIZE = BATCH_SIZE
        self.IID = IID

    @staticmethod
    def _stack_images(trainset):
        """Copy every transformed training image into one (N, 3, 224, 224) tensor."""
        stacked = torch.zeros(len(trainset), 3, 224, 224)
        for index, (image, _) in enumerate(trainset):
            stacked[index] = image
        return stacked

    def _split_client(self, client_ds):
        """Split one client's dataset 90/10 and return (train_loader, val_loader)."""
        len_val = len(client_ds) // 10
        len_train = len(client_ds) - len_val
        ds_train, ds_val = random_split(
            client_ds, [len_train, len_val], torch.Generator().manual_seed(42)
        )
        return (
            DataLoader(ds_train, batch_size=self.BATCH_SIZE, shuffle=True),
            DataLoader(ds_val, batch_size=self.BATCH_SIZE),
        )

    def load_datasets(self):
        """Download CIFAR-10 if needed and build the per-client loaders.

        Returns:
            ``(trainloaders, valloaders, testloader)`` where the first two are
            lists with one DataLoader per client and ``testloader`` iterates
            the full CIFAR-10 test split.

        Raises:
            ZeroDivisionError: if ``NUM_CLIENTS`` is 0.
        """
        preprocess = Preprocess()
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        trainset = CIFAR10("/directory/for_data/", train=True, download=True, transform=transform)
        testset = CIFAR10("/directory/for_data/", train=False, download=True, transform=transform)

        if self.IID:
            # random_split requires the lengths to sum to len(trainset). When
            # NUM_CLIENTS does not divide it evenly, the leftover samples go
            # into an extra partition that is discarded.
            partition_size, remainder = divmod(len(trainset), self.NUM_CLIENTS)
            lengths = [partition_size] * self.NUM_CLIENTS
            if remainder:
                lengths.append(remainder)
            client_sets = random_split(trainset, lengths, torch.Generator().manual_seed(42))
            if remainder:
                client_sets = client_sets[:-1]
        else:
            # The stacked tensor is roughly 30 GB for 50,000 images at torch's
            # default float32, so it is built only on this path and the tuple
            # below is its sole owner: deleting the tuple actually releases it.
            # NOTE(review): the original comment here said the data "loses
            # shape" after create_lda_partitions; not verifiable from this file.
            flwr_trainset = (
                self._stack_images(trainset),
                np.array(trainset.targets, dtype=np.int32),
            )
            partitions, _ = preprocess.create_lda_partitions(
                dataset=flwr_trainset,
                dirichlet_dist=None,
                num_partitions=self.NUM_CLIENTS,
                concentration=0.001,
                accept_imbalanced=False,
                seed=12,
            )
            del flwr_trainset
            gc.collect()

            # Each partition is indexed as (features, labels); features are
            # cast to float32 before being wrapped for the DataLoader.
            client_sets = [
                CustomDataset(part[0].astype(np.float32), part[1])
                for part in partitions
            ]
            del partitions
            gc.collect()

        trainloaders = []
        valloaders = []
        for client_ds in client_sets:
            train_loader, val_loader = self._split_client(client_ds)
            trainloaders.append(train_loader)
            valloaders.append(val_loader)
        testloader = DataLoader(testset, batch_size=self.BATCH_SIZE)
        return trainloaders, valloaders, testloader


class Cifar100_data():
    """Build federated CIFAR-100 loaders: one train/val pair per client plus a test loader.

    NUM_CLIENTS is the number of client partitions, BATCH_SIZE is forwarded to
    every DataLoader, and IID selects between equal random partitions (truthy)
    and partitions produced by ``Preprocess.create_lda_partitions`` (falsy).
    """

    def __init__(self, NUM_CLIENTS, IID, BATCH_SIZE):
        self.NUM_CLIENTS = NUM_CLIENTS
        self.BATCH_SIZE = BATCH_SIZE
        self.IID = IID

    @staticmethod
    def _stack_images(trainset):
        """Copy every transformed training image into one (N, 3, 224, 224) tensor."""
        stacked = torch.zeros(len(trainset), 3, 224, 224)
        for index, (image, _) in enumerate(trainset):
            stacked[index] = image
        return stacked

    def _split_client(self, client_ds):
        """Split one client's dataset 90/10 and return (train_loader, val_loader)."""
        len_val = len(client_ds) // 10
        len_train = len(client_ds) - len_val
        ds_train, ds_val = random_split(
            client_ds, [len_train, len_val], torch.Generator().manual_seed(42)
        )
        return (
            DataLoader(ds_train, batch_size=self.BATCH_SIZE, shuffle=True),
            DataLoader(ds_val, batch_size=self.BATCH_SIZE),
        )

    def load_datasets(self):
        """Download CIFAR-100 if needed and build the per-client loaders.

        Returns:
            ``(trainloaders, valloaders, testloader)`` where the first two are
            lists with one DataLoader per client and ``testloader`` iterates
            the full CIFAR-100 test split.

        Raises:
            ZeroDivisionError: if ``NUM_CLIENTS`` is 0.
        """
        preprocess = Preprocess()
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        trainset = CIFAR100("/directory/for_data/", train=True, download=True, transform=transform)
        testset = CIFAR100("/directory/for_data/", train=False, download=True, transform=transform)

        if self.IID:
            # random_split requires the lengths to sum to len(trainset). When
            # NUM_CLIENTS does not divide it evenly, the leftover samples go
            # into an extra partition that is discarded.
            partition_size, remainder = divmod(len(trainset), self.NUM_CLIENTS)
            lengths = [partition_size] * self.NUM_CLIENTS
            if remainder:
                lengths.append(remainder)
            client_sets = random_split(trainset, lengths, torch.Generator().manual_seed(42))
            if remainder:
                client_sets = client_sets[:-1]
        else:
            # The stacked tensor is roughly 30 GB for 50,000 images at torch's
            # default float32, so it is built only on this path and the tuple
            # below is its sole owner: deleting the tuple actually releases it.
            # NOTE(review): the original comment here said the data shape is
            # lost after create_lda_partitions; not verifiable from this file.
            flwr_trainset = (
                self._stack_images(trainset),
                np.array(trainset.targets, dtype=np.int32),
            )
            partitions, _ = preprocess.create_lda_partitions(
                dataset=flwr_trainset,
                dirichlet_dist=None,
                num_partitions=self.NUM_CLIENTS,
                concentration=0.001,
                accept_imbalanced=False,
                seed=12,
            )
            del flwr_trainset
            gc.collect()

            # Each partition is indexed as (features, labels); features are
            # cast to float32 before being wrapped for the DataLoader.
            client_sets = [
                CustomDataset(part[0].astype(np.float32), part[1])
                for part in partitions
            ]
            del partitions
            gc.collect()

        trainloaders = []
        valloaders = []
        for client_ds in client_sets:
            train_loader, val_loader = self._split_client(client_ds)
            trainloaders.append(train_loader)
            valloaders.append(val_loader)
        testloader = DataLoader(testset, batch_size=self.BATCH_SIZE)
        return trainloaders, valloaders, testloader


class WMT16Data:
    """Build federated loaders over the WMT16 de-en dataset.

    NUM_CLIENTS is the number of client partitions and BATCH_SIZE is forwarded
    to every DataLoader. IID is stored but not consulted: the training split
    is always partitioned uniformly at random.
    """

    def __init__(self, NUM_CLIENTS,IID, BATCH_SIZE):
        self.NUM_CLIENTS = NUM_CLIENTS
        self.BATCH_SIZE = BATCH_SIZE
        self.IID = IID

    def load_datasets(self):
        """Load WMT16 de-en and split its training set across the clients.

        Returns:
            ``(trainloaders, valloaders, testloader)`` where the first two are
            lists with one DataLoader per client (90/10 train/val split) and
            ``testloader`` iterates the dataset's ``'test'`` split.

        Raises:
            ZeroDivisionError: if ``NUM_CLIENTS`` is 0.
        """
        # The file's import block never brings load_dataset into scope, so the
        # call below used to fail with NameError. Importing it here keeps the
        # Hugging Face `datasets` package optional for the other loaders.
        from datasets import load_dataset

        dataset = load_dataset('wmt16', 'de-en')
        train_split = dataset['train']

        # random_split requires the lengths to sum to the dataset size. When
        # NUM_CLIENTS does not divide it evenly, the leftover samples go into
        # an extra partition that is discarded.
        partition_size, remainder = divmod(len(train_split), self.NUM_CLIENTS)
        lengths = [partition_size] * self.NUM_CLIENTS
        if remainder:
            lengths.append(remainder)
        # Splits draw from torch's global RNG; call torch.manual_seed
        # beforehand for reproducible partitions.
        client_sets = random_split(train_split, lengths)
        if remainder:
            client_sets = client_sets[:-1]

        # Create DataLoaders
        trainloaders = []
        valloaders = []
        for client_ds in client_sets:
            len_val = len(client_ds) // 10
            len_train = len(client_ds) - len_val
            ds_train, ds_val = random_split(client_ds, [len_train, len_val])
            trainloaders.append(DataLoader(ds_train, batch_size=self.BATCH_SIZE, shuffle=True))
            valloaders.append(DataLoader(ds_val, batch_size=self.BATCH_SIZE))
        testloader = DataLoader(dataset['test'], batch_size=self.BATCH_SIZE)
        return trainloaders, valloaders, testloader

class Stackoverflow():
    """Build federated loaders from per-client StackOverflow NumPy files.

    Files are read from ``data/stackoverflow/train_np/`` and
    ``data/stackoverflow/test_np/`` (relative to the working directory); each
    file must expose an ``'x'`` and a ``'y'`` entry. NUM_CLIENTS is only used
    in IID mode; in non-IID mode there is one client per file.
    """

    _TRAIN_DIR = "data/stackoverflow/train_np/"
    _TEST_DIR = "data/stackoverflow/test_np/"

    def __init__(self, NUM_CLIENTS, IID,BATCH_SIZE):
        self.NUM_CLIENTS= NUM_CLIENTS
        self.BATCH_SIZE= BATCH_SIZE
        self.IID= IID

    @staticmethod
    def _load_xy(path):
        """Read ``'x'`` (cast to float32) and ``'y'`` from one file, loading it once."""
        loaded = np.load(path)
        try:
            return loaded['x'].astype(np.float32), loaded['y']
        finally:
            # np.load returns an object holding an open file handle for .npz
            # archives; close it when a close method is available.
            close = getattr(loaded, 'close', None)
            if close is not None:
                close()

    def _loaders_from_dir(self, directory, shuffle):
        """Return one DataLoader per file in ``directory``, in sorted filename order."""
        loaders = []
        # os.listdir order is arbitrary; sorting keeps the client order stable
        # between runs and between the train and test directories.
        for name in sorted(os.listdir(directory)):
            features, labels = self._load_xy(os.path.join(directory, name))
            loaders.append(
                DataLoader(
                    CustomDataset(features, labels),
                    batch_size=self.BATCH_SIZE,
                    shuffle=shuffle,
                )
            )
        return loaders

    def non_iid_create_val_loaders(self):
        """Return one unshuffled validation DataLoader per file in the test directory."""
        return self._loaders_from_dir(self._TEST_DIR, shuffle=False)

    def non_iid_create_train_loaders(self):
        """Return one shuffled training DataLoader per file in the train directory."""
        return self._loaders_from_dir(self._TRAIN_DIR, shuffle=True)

    def create_iid_loaders(self):
        """Pool every train and test file, shuffle, and split evenly across clients.

        Returns:
            ``(trainloaders, valloaders, testloader)``; ``testloader`` is the
            first client's validation loader.

        Raises:
            ZeroDivisionError: if ``NUM_CLIENTS`` is 0.
        """
        all_x, all_y = [], []
        for directory in (self._TRAIN_DIR, self._TEST_DIR):
            for name in sorted(os.listdir(directory)):
                features, labels = self._load_xy(os.path.join(directory, name))
                all_x.append(features)
                all_y.append(labels)

        # The final pooled sample is dropped, as in the original code.
        # NOTE(review): the reason is not evident from this file.
        pooled_x = np.concatenate(all_x, axis=0)[:-1]
        pooled_y = np.concatenate(all_y)[:-1]

        # Shuffle the data
        np.random.seed(123)
        perm = np.random.permutation(len(pooled_y))
        trainset = CustomDataset(pooled_x[perm], pooled_y[perm])

        # Create dataloaders from shuffled data. Leftover samples go into an
        # extra partition so the lengths sum to len(trainset); it is discarded.
        partition_size, remainder = divmod(len(trainset), self.NUM_CLIENTS)
        lengths = [partition_size] * self.NUM_CLIENTS
        if remainder:
            lengths.append(remainder)
        client_sets = random_split(trainset, lengths, torch.Generator().manual_seed(42))
        if remainder:
            client_sets = client_sets[:-1]

        trainloaders = []
        valloaders = []
        for client_ds in client_sets:
            len_val = len(client_ds) // 10
            len_train = len(client_ds) - len_val
            ds_train, ds_val = random_split(
                client_ds, [len_train, len_val], torch.Generator().manual_seed(42)
            )
            trainloaders.append(DataLoader(ds_train, batch_size=self.BATCH_SIZE, shuffle=True))
            valloaders.append(DataLoader(ds_val, batch_size=self.BATCH_SIZE))
        return trainloaders,valloaders, valloaders[0]

    def load_datasets(self):
        """Return ``(trainloaders, valloaders, testloader)`` for the configured mode.

        In both modes ``testloader`` is the first validation loader.
        """
        if self.IID:
            trainloaders, valloaders, testloader = self.create_iid_loaders()
        else:
            trainloaders = self.non_iid_create_train_loaders()
            valloaders = self.non_iid_create_val_loaders()
            testloader = valloaders[0]
        return trainloaders, valloaders, testloader

class GLDV2():
    """Build federated loaders from per-client GLD-23k NumPy files.

    Each file in the train/test directories must expose an ``'x'`` entry
    (transposed here with axes ``(0, 3, 1, 2)``) and a ``'y'`` entry (squeezed
    on axis 1). NUM_CLIENTS is only used in IID mode; in non-IID mode there is
    one client per file.
    """

    def __init__(self, NUM_CLIENTS, IID,BATCH_SIZE):
        self.NUM_CLIENTS= NUM_CLIENTS
        self.BATCH_SIZE= BATCH_SIZE
        self.IID= IID
        self.test_directory = "/directory/for_data/gld23k_build/test_np/"
        self.train_directory = "/directory/for_data/gld23k_build/train_np/"

    @staticmethod
    def _load_xy(path):
        """Read ``'x'`` (cast to float32, untransposed) and squeezed ``'y'``, loading the file once."""
        loaded = np.load(path)
        try:
            return loaded['x'].astype(np.float32), loaded['y'].squeeze(1)
        finally:
            # np.load returns an object holding an open file handle for .npz
            # archives; close it when a close method is available.
            close = getattr(loaded, 'close', None)
            if close is not None:
                close()

    def _loaders_from_dir(self, directory, label, shuffle):
        """Return one DataLoader per file in ``directory``, in sorted filename order."""
        loaders = []
        # os.listdir order is arbitrary; sorting keeps the client order stable
        # between runs and between the train and test directories.
        for i, name in enumerate(sorted(os.listdir(directory))):
            raw_features, target_labels = self._load_xy(os.path.join(directory, name))
            # The logged shape is the stored layout, before the transpose.
            print(f"{label} feature data has shape {raw_features.shape}, client {i}, label length {len(target_labels)}")
            feature_data = raw_features.transpose(0, 3, 1, 2)
            loaders.append(
                DataLoader(
                    CustomDataset(feature_data, target_labels),
                    batch_size=self.BATCH_SIZE,
                    shuffle=shuffle,
                )
            )
        return loaders

    def non_iid_create_val_loaders(self):
        """Return one unshuffled validation DataLoader per file in the test directory."""
        return self._loaders_from_dir(self.test_directory, "Validation", shuffle=False)

    def non_iid_create_train_loaders(self):
        """Return one shuffled training DataLoader per file in the train directory."""
        return self._loaders_from_dir(self.train_directory, "Train", shuffle=True)

    def create_iid_loaders(self):
        """Pool every train and test file, shuffle, and split evenly across clients.

        Returns:
            ``(trainloaders, valloaders, testloader)``; ``testloader`` is the
            first client's validation loader.

        Raises:
            ZeroDivisionError: if ``NUM_CLIENTS`` is 0.
        """
        # Merge all the clients for trainset
        all_x, all_y = [], []
        for directory in (self.train_directory, self.test_directory):
            for name in sorted(os.listdir(directory)):
                raw_features, labels = self._load_xy(os.path.join(directory, name))
                all_x.append(raw_features.transpose(0, 3, 1, 2))
                all_y.append(labels)

        # The final pooled sample is dropped, as in the original code.
        # NOTE(review): the reason is not evident from this file.
        pooled_x = np.concatenate(all_x, axis=0)[:-1]
        pooled_y = np.concatenate(all_y)[:-1]

        # Shuffle the data
        np.random.seed(123)
        perm = np.random.permutation(len(pooled_y))
        trainset = CustomDataset(pooled_x[perm], pooled_y[perm])

        # Create dataloaders. Leftover samples go into an extra partition so
        # the lengths sum to len(trainset); that partition is discarded.
        partition_size, remainder = divmod(len(trainset), self.NUM_CLIENTS)
        lengths = [partition_size] * self.NUM_CLIENTS
        if remainder:
            lengths.append(remainder)
        client_sets = random_split(trainset, lengths, torch.Generator().manual_seed(42))
        if remainder:
            client_sets = client_sets[:-1]

        trainloaders = []
        valloaders = []
        for client_ds in client_sets:
            len_val = len(client_ds) // 10
            len_train = len(client_ds) - len_val
            ds_train, ds_val = random_split(
                client_ds, [len_train, len_val], torch.Generator().manual_seed(42)
            )
            trainloaders.append(DataLoader(ds_train, batch_size=self.BATCH_SIZE, shuffle=True))
            valloaders.append(DataLoader(ds_val, batch_size=self.BATCH_SIZE))
        return trainloaders,valloaders, valloaders[0]

    def load_datasets(self):
        """Return ``(trainloaders, valloaders, testloader)`` for the configured mode.

        In both modes ``testloader`` is the first validation loader.
        """
        if self.IID:
            trainloaders, valloaders, testloader = self.create_iid_loaders()
        else:
            trainloaders = self.non_iid_create_train_loaders()
            valloaders = self.non_iid_create_val_loaders()
            testloader = valloaders[0]
        return trainloaders, valloaders, testloader